Generative Adversarial Imitation Learning (GAIL) — low-level PyTorch#
GAIL is an imitation learning algorithm: it learns a policy (\pi_\theta(a\mid s)) from expert demonstrations without access to the expert’s reward function.
The core idea is adversarial training:
a discriminator (D_\phi(s,a)) tries to tell expert ((s,a)) pairs apart from policy-generated ((s,a)) pairs
the policy is trained to fool the discriminator, using a reward derived from (D_\phi)
This notebook implements a small but complete GAIL loop from scratch in PyTorch:
a toy 2D navigation environment (no Gym dependency)
a hand-coded expert to generate demonstrations
a discriminator network (D_\phi)
an actor-critic policy (\pi_\theta) trained with PPO using the discriminator reward
Plotly curves for discriminator loss, policy learning, and episodic rewards
Learning goals#
By the end you should be able to:
write down the GAIL min–max objective and explain the GAN analogy
implement a discriminator over ((s,a)) and train it with cross-entropy
turn discriminator outputs into a reward signal for RL
implement the PPO update (clipped objective + value loss + entropy bonus)
monitor training with Plotly: discriminator loss + episodic return
Notebook roadmap#
GAIL objective + adversarial training equations
A tiny offline-friendly environment + expert demonstrations
Low-level PyTorch: policy/value networks + discriminator
Training loop (alternate discriminator updates and policy PPO updates)
Plotly diagnostics: discriminator loss, policy learning, episodic rewards
Stable-Baselines GAIL notes + hyperparameters (end)
import time
import warnings
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import os
import plotly.io as pio
import torch
import torch.nn as nn
import torch.nn.functional as F
pio.templates.default = "plotly_white"
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
# Some environments emit a CUDA-availability warning even when using CPU tensors.
warnings.filterwarnings("ignore", message=r"CUDA initialization:.*")
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.set_num_threads(1)
DEVICE = torch.device("cpu")
np.set_printoptions(precision=4, suppress=True)
# --- Run configuration ---
FAST_RUN = True # set False for longer training
# Environment
N_ENVS = 32 if FAST_RUN else 128
MAX_STEPS = 60
STEP_SIZE = 0.20
NOISE_STD = 0.00
GOAL_RADIUS = 0.12
# Expert dataset
N_EXPERT_EPISODES = 300 if FAST_RUN else 1500
# GAIL + PPO
ITERATIONS = 25 if FAST_RUN else 200
STEPS_PER_ITER = 128 if FAST_RUN else 512
GAMMA = 0.99
LAMBDA_GAE = 0.95
# Discriminator updates
D_LR = 3e-4
D_EPOCHS = 2 if FAST_RUN else 5
D_BATCH_SIZE = 512
# PPO updates
PI_LR = 3e-4
PPO_EPOCHS = 4 if FAST_RUN else 10
PPO_BATCH_SIZE = 1024
CLIP_EPS = 0.2
VF_COEF = 0.5
ENT_COEF = 0.01
# Eval
EVAL_EVERY = 5
EVAL_EPISODES = 200 if FAST_RUN else 1000
1) GAIL: adversarial training objective (equations)#
Setup#
You have expert demonstrations (state-action pairs) sampled from an expert policy (\pi_E):
[ (s,a) \sim \pi_E. ]
You want to learn a policy (\pi_\theta) that induces (approximately) the same occupancy measure as the expert.
Discriminator objective#
GAIL uses a discriminator (D_\phi(s,a)\in(0,1)) that outputs the probability that an ((s,a)) pair came from the expert. It is trained like a GAN discriminator:
[ \max_{\phi}\; \mathbb{E}_{(s,a)\sim \pi_E}[\log D_\phi(s,a)] + \mathbb{E}_{(s,a)\sim \pi_\theta}[\log(1 - D_\phi(s,a))]. ]
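In code, maximizing this objective is the same as minimizing binary cross-entropy with label 1 for expert pairs and label 0 for policy pairs, which is exactly how the discriminator update in this notebook is written. A quick standalone numeric check of both terms (using a made-up logit):

```python
import torch
import torch.nn.functional as F

# A fake discriminator logit f for one (s, a) pair.
f = torch.tensor([1.3])

# Expert term: maximizing log D = log sigmoid(f) <=> minimizing BCE with target 1.
bce_expert = F.binary_cross_entropy_with_logits(f, torch.ones_like(f))
neg_log_d = -torch.log(torch.sigmoid(f))
assert torch.allclose(bce_expert, neg_log_d)

# Policy term: maximizing log(1 - D) <=> minimizing BCE with target 0.
bce_gen = F.binary_cross_entropy_with_logits(f, torch.zeros_like(f))
neg_log_1md = -torch.log(1.0 - torch.sigmoid(f))
assert torch.allclose(bce_gen, neg_log_1md)
```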
Policy (generator) objective#
The policy plays the role of the GAN generator: it tries to produce ((s,a)) that the discriminator labels as expert. A common generator loss is:
[ \min_{\theta}\; \mathbb{E}_{(s,a)\sim \pi_\theta}[\log(1 - D_\phi(s,a))] - \lambda\, H(\pi_\theta), ]
where (H(\pi_\theta)) is the policy entropy (encourages exploration).
Turning the discriminator into a reward#
To train (\pi_\theta) with RL, we convert discriminator outputs into a reward:
[ \hat r_\phi(s,a) = -\log(1 - D_\phi(s,a)). ]
If the discriminator uses a logit (f_\phi(s,a)) (so (D_\phi = \sigma(f_\phi))), this reward has a numerically-stable form:
[ \hat r_\phi(s,a) = -\log(\sigma(-f_\phi(s,a))) = \mathrm{softplus}(f_\phi(s,a)). ]
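A one-line numeric check of this identity over a range of logits (standalone sketch):

```python
import torch
import torch.nn.functional as F

f = torch.linspace(-5.0, 5.0, steps=11)            # a range of discriminator logits
reward_naive = -torch.log(1.0 - torch.sigmoid(f))  # -log(1 - D), overflows for large f
reward_stable = F.softplus(f)                      # softplus(f), numerically stable form
assert torch.allclose(reward_naive, reward_stable, atol=1e-6)
```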
Policy optimization (we’ll use PPO)#
We’ll optimize (\pi_\theta) with PPO. With advantage estimates (\hat A_t), PPO’s clipped objective is:
[ L^{\text{CLIP}}(\theta) = \mathbb{E}_t\Big[\min\big(r_t(\theta)\hat A_t,\; \mathrm{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat A_t\big)\Big], ]
where (r_t(\theta)=\exp(\log\pi_\theta(a_t\mid s_t)-\log\pi_{\theta_{\text{old}}}(a_t\mid s_t))).
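To make the clipping concrete, here is a tiny numeric sketch with (\epsilon = 0.2) and positive advantages (made-up ratios):

```python
import torch

eps = 0.2
ratio = torch.tensor([0.5, 1.0, 1.5])  # pi_new / pi_old per sample
adv = torch.tensor([1.0, 1.0, 1.0])    # positive advantages

unclipped = ratio * adv
clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * adv
objective = torch.minimum(unclipped, clipped)
# per sample: min(0.5, 0.8) = 0.5, min(1.0, 1.0) = 1.0, min(1.5, 1.2) = 1.2
print(objective)  # tensor([0.5000, 1.0000, 1.2000])
```

With positive advantages, clipping caps how much a single update can increase an action's probability; with negative advantages, it caps how much it can be decreased.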
2) A tiny environment (no downloads, no Gym)#
We’ll use a simple 2D point navigation task:
observation (s = (x,y)\in[-1,1]^2)
5 discrete actions: stay / up / down / left / right
start position is random
goal is the origin
episode ends when the point enters a goal radius or hits a step limit
We’ll generate expert demonstrations using a greedy hand-coded expert that always moves along the coordinate with the largest absolute value toward 0.
class VectorPointNav2D:
def __init__(
self,
n_envs: int,
max_steps: int = 60,
step_size: float = 0.20,
noise_std: float = 0.00,
goal: tuple[float, float] = (0.0, 0.0),
goal_radius: float = 0.12,
seed: int = 0,
):
self.n_envs = int(n_envs)
self.max_steps = int(max_steps)
self.step_size = float(step_size)
self.noise_std = float(noise_std)
self.goal = np.array(goal, dtype=np.float32)
self.goal_radius = float(goal_radius)
self.rng = np.random.default_rng(seed)
self.pos = np.zeros((self.n_envs, 2), dtype=np.float32)
self.t = np.zeros(self.n_envs, dtype=np.int32)
@property
def obs_dim(self) -> int:
return 2
@property
def n_actions(self) -> int:
# 0 stay, 1 up, 2 down, 3 left, 4 right
return 5
def reset(self) -> np.ndarray:
self.t[:] = 0
self.pos[:] = self.rng.uniform(low=-1.0, high=1.0, size=(self.n_envs, 2)).astype(np.float32)
return self.pos.copy()
def reset_done(self, done_mask: np.ndarray) -> None:
idx = np.where(done_mask)[0]
if len(idx) == 0:
return
self.t[idx] = 0
self.pos[idx] = self.rng.uniform(low=-1.0, high=1.0, size=(len(idx), 2)).astype(np.float32)
def step(self, actions: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
actions = np.asarray(actions, dtype=np.int64)
assert actions.shape == (self.n_envs,)
move = np.zeros((self.n_envs, 2), dtype=np.float32)
move[actions == 1, 1] = 1.0
move[actions == 2, 1] = -1.0
move[actions == 3, 0] = -1.0
move[actions == 4, 0] = 1.0
noise = self.rng.normal(loc=0.0, scale=self.noise_std, size=(self.n_envs, 2)).astype(np.float32)
self.pos = np.clip(self.pos + self.step_size * move + noise, -1.0, 1.0)
self.t += 1
dist = np.linalg.norm(self.pos - self.goal[None, :], axis=1)
success = dist < self.goal_radius
done = success | (self.t >= self.max_steps)
# True environment reward (for monitoring): small time penalty, big success bonus
reward = -0.01 * np.ones(self.n_envs, dtype=np.float32)
reward = reward + success.astype(np.float32) * 1.0
info = {
"dist": dist.astype(np.float32),
"success": success.astype(np.bool_),
}
return self.pos.copy(), reward, done.astype(np.bool_), info
def expert_policy(obs: np.ndarray) -> np.ndarray:
# Greedy expert: move along the largest coordinate toward 0.
x = obs[:, 0]
y = obs[:, 1]
ax = np.abs(x)
ay = np.abs(y)
actions = np.zeros(len(obs), dtype=np.int64)
choose_x = ax >= ay
actions[choose_x & (x > 0)] = 3 # left
actions[choose_x & (x < 0)] = 4 # right
actions[(~choose_x) & (y > 0)] = 2 # down
actions[(~choose_x) & (y < 0)] = 1 # up
return actions
# Quick look at one expert trajectory
env_one = VectorPointNav2D(
n_envs=1,
max_steps=MAX_STEPS,
step_size=STEP_SIZE,
noise_std=NOISE_STD,
goal_radius=GOAL_RADIUS,
seed=SEED,
)
obs = env_one.reset()
traj = [obs[0].copy()]
actions = []
rewards = []
done = np.array([False])
while not done[0]:
a = expert_policy(obs)[0]
obs, r, done, info = env_one.step(np.array([a]))
traj.append(obs[0].copy())
actions.append(int(a))
rewards.append(float(r[0]))
traj = np.stack(traj)
fig = go.Figure()
fig.add_trace(go.Scatter(x=traj[:, 0], y=traj[:, 1], mode="lines+markers", name="expert"))
fig.add_trace(go.Scatter(x=[0], y=[0], mode="markers", marker=dict(size=12, symbol="x"), name="goal"))
fig.update_layout(title="One expert trajectory (2D point → origin)", xaxis_title="x", yaxis_title="y")
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig.show()
print("steps:", len(actions), "episode_return:", sum(rewards), "success:", info["success"][0])
steps: 4 episode_return: 0.9600000102072954 success: True
3) Expert demonstrations dataset#
GAIL trains the discriminator on expert and policy ((s,a)) samples. We’ll generate a dataset of expert state-action pairs by running the expert in the environment.
def collect_expert_pairs(n_episodes: int, seed: int) -> tuple[np.ndarray, np.ndarray]:
env = VectorPointNav2D(
n_envs=1,
max_steps=MAX_STEPS,
step_size=STEP_SIZE,
noise_std=NOISE_STD,
goal_radius=GOAL_RADIUS,
seed=seed,
)
obs_list: list[np.ndarray] = []
act_list: list[int] = []
for _ in range(int(n_episodes)):
obs = env.reset()
done = np.array([False])
while not done[0]:
a = int(expert_policy(obs)[0])
obs_list.append(obs[0].copy())
act_list.append(a)
obs, r, done, info = env.step(np.array([a]))
expert_obs = np.stack(obs_list).astype(np.float32)
expert_acts = np.array(act_list, dtype=np.int64)
return expert_obs, expert_acts
expert_obs, expert_acts = collect_expert_pairs(N_EXPERT_EPISODES, seed=SEED)
print("expert_obs", expert_obs.shape, "expert_acts", expert_acts.shape)
fig = px.histogram(x=expert_acts, nbins=5, title="Expert action histogram")
fig.update_layout(xaxis_title="action (0 stay, 1 up, 2 down, 3 left, 4 right)", yaxis_title="count")
fig.show()
expert_obs (2068, 2) expert_acts (2068,)
4) Low-level PyTorch: policy/value network and discriminator#
We’ll use:
Policy/value: a small shared MLP with two heads
policy head outputs categorical logits over 5 actions
value head outputs (V_\theta(s))
Discriminator: an MLP over ((s, \text{one-hot}(a))) returning a logit (f_\phi(s,a))
def one_hot(actions: torch.Tensor, n_actions: int) -> torch.Tensor:
return F.one_hot(actions.long(), num_classes=n_actions).float()
class ActorCritic(nn.Module):
def __init__(self, obs_dim: int, n_actions: int, hidden_sizes: tuple[int, ...] = (64, 64)):
super().__init__()
layers: list[nn.Module] = []
in_dim = obs_dim
for h in hidden_sizes:
layers += [nn.Linear(in_dim, h), nn.Tanh()]
in_dim = h
self.shared = nn.Sequential(*layers)
self.pi = nn.Linear(in_dim, n_actions)
self.v = nn.Linear(in_dim, 1)
def forward(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
x = self.shared(obs)
logits = self.pi(x)
value = self.v(x).squeeze(-1)
return logits, value
def get_action_and_value(
self,
obs: torch.Tensor,
action: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
logits, value = self.forward(obs)
dist = torch.distributions.Categorical(logits=logits)
if action is None:
action = dist.sample()
logp = dist.log_prob(action)
entropy = dist.entropy()
return action, logp, entropy, value
class Discriminator(nn.Module):
def __init__(self, obs_dim: int, n_actions: int, hidden_sizes: tuple[int, ...] = (128, 128)):
super().__init__()
in_dim = obs_dim + n_actions
layers: list[nn.Module] = []
for h in hidden_sizes:
layers += [nn.Linear(in_dim, h), nn.Tanh()]
in_dim = h
layers += [nn.Linear(in_dim, 1)]
self.net = nn.Sequential(*layers)
self.n_actions = n_actions
def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
x = torch.cat([obs, one_hot(actions, self.n_actions)], dim=-1)
return self.net(x).squeeze(-1) # logits
policy = ActorCritic(obs_dim=2, n_actions=5).to(DEVICE)
disc = Discriminator(obs_dim=2, n_actions=5).to(DEVICE)
pi_opt = torch.optim.Adam(policy.parameters(), lr=PI_LR)
d_opt = torch.optim.Adam(disc.parameters(), lr=D_LR)
5) Rollouts, GAE, discriminator update, PPO update#
We’ll collect policy rollouts from a vectorized environment, then alternate:
Discriminator update(s) using expert pairs and current policy pairs
Policy PPO update(s) using discriminator-derived rewards
We’ll use GAE((\gamma,\lambda)) for advantages.
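For reference, the GAE((\gamma,\lambda)) recursion that compute_gae implements is:
[ \delta_t = r_t + \gamma (1 - d_t)\, V(s_{t+1}) - V(s_t), \qquad \hat A_t = \delta_t + \gamma\lambda (1 - d_t)\, \hat A_{t+1}, ]
with (\hat A_{T}) seeded at 0, (d_t) the done flag, and value targets (\hat R_t = \hat A_t + V(s_t)).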
def rollout(env: VectorPointNav2D, policy: ActorCritic, n_steps: int) -> dict:
obs = env.reset()
n_envs = env.n_envs
obs_buf = []
act_buf = []
logp_buf = []
val_buf = []
done_buf = []
true_r_buf = []
ep_returns = np.zeros(n_envs, dtype=np.float32)
completed_returns: list[float] = []
completed_success: list[bool] = []
for _ in range(int(n_steps)):
obs_t = torch.tensor(obs, dtype=torch.float32, device=DEVICE)
with torch.no_grad():
action, logp, entropy, value = policy.get_action_and_value(obs_t)
act_np = action.cpu().numpy()
next_obs, true_r, done, info = env.step(act_np)
obs_buf.append(obs.copy())
act_buf.append(act_np.copy())
logp_buf.append(logp.cpu().numpy())
val_buf.append(value.cpu().numpy())
done_buf.append(done.copy())
true_r_buf.append(true_r.copy())
ep_returns += true_r
for i in range(n_envs):
if done[i]:
completed_returns.append(float(ep_returns[i]))
completed_success.append(bool(info["success"][i]))
ep_returns[i] = 0.0
env.reset_done(done)
obs = env.pos.copy()
with torch.no_grad():
_, last_values = policy.forward(torch.tensor(obs, dtype=torch.float32, device=DEVICE))
return {
"obs": np.asarray(obs_buf, dtype=np.float32),
"actions": np.asarray(act_buf, dtype=np.int64),
"logp": np.asarray(logp_buf, dtype=np.float32),
"values": np.asarray(val_buf, dtype=np.float32),
"dones": np.asarray(done_buf, dtype=np.bool_),
"true_rewards": np.asarray(true_r_buf, dtype=np.float32),
"last_values": last_values.cpu().numpy().astype(np.float32),
"completed_returns": completed_returns,
"completed_success": completed_success,
}
def compute_gae(
rewards: np.ndarray,
values: np.ndarray,
dones: np.ndarray,
last_values: np.ndarray,
gamma: float,
lam: float,
) -> tuple[np.ndarray, np.ndarray]:
# GAE-Lambda. Shapes: rewards/values/dones are (T, N). last_values is (N,).
T, N = rewards.shape
adv = np.zeros((T, N), dtype=np.float32)
last_adv = np.zeros(N, dtype=np.float32)
next_values = last_values.astype(np.float32)
for t in reversed(range(T)):
mask = 1.0 - dones[t].astype(np.float32)
delta = rewards[t] + gamma * next_values * mask - values[t]
last_adv = delta + gamma * lam * mask * last_adv
adv[t] = last_adv
next_values = values[t]
returns = adv + values
return adv, returns
def train_discriminator(
disc: Discriminator,
opt: torch.optim.Optimizer,
expert_obs: np.ndarray,
expert_acts: np.ndarray,
gen_obs: np.ndarray,
gen_acts: np.ndarray,
epochs: int,
batch_size: int,
) -> float:
# BCE discriminator update. expert label=1, generator label=0.
disc.train()
n_gen = len(gen_obs)
n_exp = len(expert_obs)
n = min(n_gen, n_exp)
idx_g = np.random.randint(0, n_gen, size=n)
idx_e = np.random.randint(0, n_exp, size=n)
g_obs = torch.tensor(gen_obs[idx_g], dtype=torch.float32, device=DEVICE)
g_act = torch.tensor(gen_acts[idx_g], dtype=torch.int64, device=DEVICE)
e_obs = torch.tensor(expert_obs[idx_e], dtype=torch.float32, device=DEVICE)
e_act = torch.tensor(expert_acts[idx_e], dtype=torch.int64, device=DEVICE)
losses: list[float] = []
for _ in range(int(epochs)):
perm = torch.randperm(n, device=DEVICE)
for start in range(0, n, int(batch_size)):
mb = perm[start : start + int(batch_size)]
logits_g = disc(g_obs[mb], g_act[mb])
logits_e = disc(e_obs[mb], e_act[mb])
loss_g = F.binary_cross_entropy_with_logits(logits_g, torch.zeros_like(logits_g))
loss_e = F.binary_cross_entropy_with_logits(logits_e, torch.ones_like(logits_e))
loss = loss_g + loss_e
opt.zero_grad(set_to_none=True)
loss.backward()
opt.step()
losses.append(float(loss.detach().cpu()))
return float(np.mean(losses))
def ppo_update(
policy: ActorCritic,
opt: torch.optim.Optimizer,
obs: np.ndarray,
actions: np.ndarray,
old_logp: np.ndarray,
advantages: np.ndarray,
returns: np.ndarray,
clip_eps: float,
vf_coef: float,
ent_coef: float,
epochs: int,
batch_size: int,
) -> dict:
policy.train()
n = len(obs)
obs_t = torch.tensor(obs, dtype=torch.float32, device=DEVICE)
act_t = torch.tensor(actions, dtype=torch.int64, device=DEVICE)
old_logp_t = torch.tensor(old_logp, dtype=torch.float32, device=DEVICE)
adv_t = torch.tensor(advantages, dtype=torch.float32, device=DEVICE)
ret_t = torch.tensor(returns, dtype=torch.float32, device=DEVICE)
adv_t = (adv_t - adv_t.mean()) / (adv_t.std() + 1e-8)
total_losses: list[float] = []
policy_losses: list[float] = []
value_losses: list[float] = []
entropies: list[float] = []
approx_kls: list[float] = []
for _ in range(int(epochs)):
perm = torch.randperm(n, device=DEVICE)
for start in range(0, n, int(batch_size)):
mb = perm[start : start + int(batch_size)]
_, logp, entropy, value = policy.get_action_and_value(obs_t[mb], act_t[mb])
ratio = torch.exp(logp - old_logp_t[mb])
pg1 = ratio * adv_t[mb]
pg2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv_t[mb]
policy_loss = -torch.mean(torch.minimum(pg1, pg2))
value_loss = F.mse_loss(value, ret_t[mb])
entropy_bonus = torch.mean(entropy)
loss = policy_loss + vf_coef * value_loss - ent_coef * entropy_bonus
opt.zero_grad(set_to_none=True)
loss.backward()
nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
opt.step()
approx_kl = torch.mean(old_logp_t[mb] - logp).detach().cpu().item()
total_losses.append(float(loss.detach().cpu()))
policy_losses.append(float(policy_loss.detach().cpu()))
value_losses.append(float(value_loss.detach().cpu()))
entropies.append(float(entropy_bonus.detach().cpu()))
approx_kls.append(float(approx_kl))
return {
"loss": float(np.mean(total_losses)),
"policy_loss": float(np.mean(policy_losses)),
"value_loss": float(np.mean(value_losses)),
"entropy": float(np.mean(entropies)),
"approx_kl": float(np.mean(approx_kls)),
}
6) Train GAIL (alternate D and PPO updates)#
We’ll track:
discriminator loss
PPO diagnostics (policy/value/entropy/KL)
episodic returns from the true environment reward (monitoring only)
evaluation return + success rate
def gail_reward_from_logits(logits: torch.Tensor) -> torch.Tensor:
# r = -log(1 - sigmoid(logits)) = softplus(logits)
return F.softplus(logits)
def evaluate_policy(policy: ActorCritic, seed: int, n_episodes: int) -> dict:
env = VectorPointNav2D(
n_envs=1,
max_steps=MAX_STEPS,
step_size=STEP_SIZE,
noise_std=NOISE_STD,
goal_radius=GOAL_RADIUS,
seed=seed,
)
returns: list[float] = []
successes: list[bool] = []
steps: list[int] = []
for _ in range(int(n_episodes)):
obs = env.reset()
done = np.array([False])
ep_return = 0.0
ep_steps = 0
ep_success = False
while not done[0]:
obs_t = torch.tensor(obs, dtype=torch.float32, device=DEVICE)
with torch.no_grad():
logits, _ = policy.forward(obs_t)
action = torch.argmax(logits, dim=-1)
obs, r, done, info = env.step(action.cpu().numpy())
ep_return += float(r[0])
ep_steps += 1
ep_success = bool(info["success"][0])
returns.append(ep_return)
successes.append(ep_success)
steps.append(ep_steps)
return {
"return_mean": float(np.mean(returns)),
"return_std": float(np.std(returns)),
"success_rate": float(np.mean(successes)),
"steps_mean": float(np.mean(steps)),
}
def evaluate_expert(seed: int, n_episodes: int) -> dict:
env = VectorPointNav2D(
n_envs=1,
max_steps=MAX_STEPS,
step_size=STEP_SIZE,
noise_std=NOISE_STD,
goal_radius=GOAL_RADIUS,
seed=seed,
)
returns: list[float] = []
successes: list[bool] = []
steps: list[int] = []
for _ in range(int(n_episodes)):
obs = env.reset()
done = np.array([False])
ep_return = 0.0
ep_steps = 0
ep_success = False
while not done[0]:
a = int(expert_policy(obs)[0])
obs, r, done, info = env.step(np.array([a]))
ep_return += float(r[0])
ep_steps += 1
ep_success = bool(info["success"][0])
returns.append(ep_return)
successes.append(ep_success)
steps.append(ep_steps)
return {
"return_mean": float(np.mean(returns)),
"return_std": float(np.std(returns)),
"success_rate": float(np.mean(successes)),
"steps_mean": float(np.mean(steps)),
}
env = VectorPointNav2D(
n_envs=N_ENVS,
max_steps=MAX_STEPS,
step_size=STEP_SIZE,
noise_std=NOISE_STD,
goal_radius=GOAL_RADIUS,
seed=SEED,
)
# Baseline: how good is the expert on this environment reward?
expert_eval = evaluate_expert(seed=SEED + 123, n_episodes=EVAL_EPISODES)
expert_eval
{'return_mean': 0.8834500104933977,
'return_std': 0.322598197569875,
'success_rate': 0.955,
'steps_mean': 7.155}
disc_loss_hist: list[float] = []
ppo_loss_hist: list[float] = []
ppo_policy_loss_hist: list[float] = []
ppo_value_loss_hist: list[float] = []
ppo_entropy_hist: list[float] = []
ppo_kl_hist: list[float] = []
train_ep_returns: list[float] = []
train_ep_success: list[bool] = []
train_ep_iter: list[int] = []
eval_iters: list[int] = []
eval_return_mean: list[float] = []
eval_success_rate: list[float] = []
start = time.time()
for it in range(int(ITERATIONS)):
data = rollout(env, policy, n_steps=STEPS_PER_ITER)
# Flatten rollout buffers for discriminator/policy updates
obs = data["obs"].reshape(-1, 2)
acts = data["actions"].reshape(-1)
old_logp = data["logp"].reshape(-1)
# 1) Discriminator update
dloss = train_discriminator(
disc=disc,
opt=d_opt,
expert_obs=expert_obs,
expert_acts=expert_acts,
gen_obs=obs,
gen_acts=acts,
epochs=D_EPOCHS,
batch_size=D_BATCH_SIZE,
)
# 2) Compute discriminator reward for the policy rollout
disc.eval()
with torch.no_grad():
logits = disc(
torch.tensor(obs, dtype=torch.float32, device=DEVICE),
torch.tensor(acts, dtype=torch.int64, device=DEVICE),
)
gail_rewards = gail_reward_from_logits(logits).cpu().numpy().reshape(STEPS_PER_ITER, N_ENVS)
# 3) PPO update using GAE on the discriminator reward
adv, rets = compute_gae(
rewards=gail_rewards,
values=data["values"],
dones=data["dones"],
last_values=data["last_values"],
gamma=GAMMA,
lam=LAMBDA_GAE,
)
ppo_stats = ppo_update(
policy=policy,
opt=pi_opt,
obs=obs,
actions=acts,
old_logp=old_logp,
advantages=adv.reshape(-1),
returns=rets.reshape(-1),
clip_eps=CLIP_EPS,
vf_coef=VF_COEF,
ent_coef=ENT_COEF,
epochs=PPO_EPOCHS,
batch_size=PPO_BATCH_SIZE,
)
disc_loss_hist.append(dloss)
ppo_loss_hist.append(ppo_stats["loss"])
ppo_policy_loss_hist.append(ppo_stats["policy_loss"])
ppo_value_loss_hist.append(ppo_stats["value_loss"])
ppo_entropy_hist.append(ppo_stats["entropy"])
ppo_kl_hist.append(ppo_stats["approx_kl"])
# Record true episodic returns completed during this iteration
for r, s in zip(data["completed_returns"], data["completed_success"]):
train_ep_returns.append(float(r))
train_ep_success.append(bool(s))
train_ep_iter.append(it)
# Evaluate periodically
if (it + 1) % int(EVAL_EVERY) == 0:
stats = evaluate_policy(policy, seed=SEED + 999 + it, n_episodes=EVAL_EPISODES)
eval_iters.append(it + 1)
eval_return_mean.append(stats["return_mean"])
eval_success_rate.append(stats["success_rate"])
elapsed = time.time() - start
elapsed
3.971886157989502
7) Plotly diagnostics#
Required plots:
discriminator loss
policy learning (evaluation return + success rate)
episodic rewards (environment return per episode)
# Discriminator loss over iterations
df_disc = pd.DataFrame({
"iteration": np.arange(1, len(disc_loss_hist) + 1),
"disc_loss": disc_loss_hist,
})
fig = px.line(df_disc, x="iteration", y="disc_loss", title="Discriminator loss")
fig.update_layout(xaxis_title="iteration", yaxis_title="BCE loss")
fig.show()
# Policy learning: evaluation return + success rate
df_eval = pd.DataFrame({
"iteration": eval_iters,
"eval_return_mean": eval_return_mean,
"eval_success_rate": eval_success_rate,
})
fig = go.Figure()
fig.add_trace(go.Scatter(x=df_eval["iteration"], y=df_eval["eval_return_mean"], mode="lines+markers", name="eval return"))
fig.add_trace(go.Scatter(x=df_eval["iteration"], y=df_eval["eval_success_rate"], mode="lines+markers", name="success rate", yaxis="y2"))
fig.update_layout(
title="Policy learning (evaluation)",
xaxis=dict(title="iteration"),
yaxis=dict(title="mean episodic return"),
yaxis2=dict(title="success rate", overlaying="y", side="right", range=[0, 1]),
)
fig.show()
print("Expert baseline:", expert_eval)
Expert baseline: {'return_mean': 0.8834500104933977, 'return_std': 0.322598197569875, 'success_rate': 0.955, 'steps_mean': 7.155}
# Episodic rewards collected during training
if len(train_ep_returns) == 0:
print("No completed episodes recorded (increase STEPS_PER_ITER or MAX_STEPS).")
else:
df_ep = pd.DataFrame({
"episode": np.arange(1, len(train_ep_returns) + 1),
"return": train_ep_returns,
"success": train_ep_success,
"iteration": train_ep_iter,
})
fig = px.scatter(
df_ep,
x="episode",
y="return",
color="success",
title="Episodic returns during training (true env reward)",
opacity=0.6,
)
fig.update_layout(xaxis_title="episode", yaxis_title="episodic return")
fig.show()
window = 25
if len(df_ep) >= window:
ma = df_ep["return"].rolling(window=window).mean()
fig = go.Figure()
fig.add_trace(go.Scatter(x=df_ep["episode"], y=df_ep["return"], mode="markers", name="return", opacity=0.35))
fig.add_trace(go.Scatter(x=df_ep["episode"], y=ma, mode="lines", name=f"moving avg (window={window})"))
fig.update_layout(title="Episodic return with moving average", xaxis_title="episode", yaxis_title="return")
fig.show()
8) Visual sanity check: expert vs learned trajectories#
A qualitative check: roll out the expert and the learned policy from the same start state.
def rollout_single(policy: ActorCritic, seed: int, use_expert: bool) -> dict:
env = VectorPointNav2D(
n_envs=1,
max_steps=MAX_STEPS,
step_size=STEP_SIZE,
noise_std=NOISE_STD,
goal_radius=GOAL_RADIUS,
seed=seed,
)
obs = env.reset()
traj = [obs[0].copy()]
rewards = []
done = np.array([False])
while not done[0]:
if use_expert:
a = int(expert_policy(obs)[0])
else:
obs_t = torch.tensor(obs, dtype=torch.float32, device=DEVICE)
with torch.no_grad():
logits, _ = policy.forward(obs_t)
a = int(torch.argmax(logits, dim=-1).cpu().item())
obs, r, done, info = env.step(np.array([a]))
traj.append(obs[0].copy())
rewards.append(float(r[0]))
return {
"traj": np.stack(traj),
"return": float(sum(rewards)),
"success": bool(info["success"][0]),
}
seed = SEED + 2025
expert_roll = rollout_single(policy, seed=seed, use_expert=True)
learned_roll = rollout_single(policy, seed=seed, use_expert=False)
fig = go.Figure()
fig.add_trace(go.Scatter(x=expert_roll["traj"][:, 0], y=expert_roll["traj"][:, 1], mode="lines+markers", name="expert"))
fig.add_trace(go.Scatter(x=learned_roll["traj"][:, 0], y=learned_roll["traj"][:, 1], mode="lines+markers", name="learned"))
fig.add_trace(go.Scatter(x=[0], y=[0], mode="markers", marker=dict(size=12, symbol="x"), name="goal"))
fig.update_layout(
title="Expert vs learned trajectory (same start)",
xaxis_title="x",
yaxis_title="y",
)
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig.show()
print("expert", {k: expert_roll[k] for k in ["return", "success"]})
print("learned", {k: learned_roll[k] for k in ["return", "success"]})
expert {'return': 0.9600000102072954, 'success': True}
learned {'return': -0.5999999865889549, 'success': False}
9) Stable-Baselines GAIL (implementation exists) + hyperparameters#
Does Stable-Baselines implement GAIL?#
Yes: Stable-Baselines (the TensorFlow-based library, not SB3) ships a GAIL class. Upstream docs/source show:
stable_baselines.GAIL exists and is TRPO-based (inherits from TRPO)
it expects an ExpertDataset
it requires OpenMPI support (for the MPI-based TRPO implementation)
Example from upstream docs (Pendulum):
import gym
from stable_baselines import GAIL, SAC
from stable_baselines.gail import ExpertDataset, generate_expert_traj
# Generate expert trajectories (train expert)
model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1)
generate_expert_traj(model, 'expert_pendulum', n_timesteps=100, n_episodes=10)
# Load the expert dataset
dataset = ExpertDataset(expert_path='expert_pendulum.npz', traj_limitation=10, verbose=1)
model = GAIL('MlpPolicy', 'Pendulum-v0', dataset, verbose=1)
model.learn(total_timesteps=1000)
model.save("gail_pendulum")
Hyperparameters (Stable-Baselines GAIL)#
From the Stable-Baselines GAIL docstring/source, the key knobs are:
TRPO / policy optimization (inherited)
gamma: discount factor
timesteps_per_batch: rollout horizon per TRPO batch
max_kl: KL constraint threshold (trust-region size)
cg_iters: conjugate-gradient iterations (for the TRPO step)
lam: GAE (\lambda)
entcoeff: entropy regularization coefficient
cg_damping: damping for conjugate gradient / Fisher-vector products
vf_stepsize: value function optimizer step size
vf_iters: value function training iterations per update
hidden_size: MLP hidden sizes for the policy/value network
GAIL-specific (how often/fast to train each player)
g_step: number of generator/policy steps per epoch
d_step: number of discriminator steps per epoch
d_stepsize: discriminator/reward-giver learning rate
hidden_size_adversary: discriminator hidden size
adversary_entcoeff: entropy term used in the adversary loss (stabilization)
Notes:
If d_step is too large (or d_stepsize too high), the discriminator can become too strong → sparse/unstable rewards.
If g_step is too large with a weak discriminator, the policy can overfit to a stale reward signal.
SB3 note#
Stable-Baselines3 does not ship GAIL in core; in practice, people often use the separate imitation library (HumanCompatibleAI) with SB3 policies.
Exercises#
Replace PPO with a TRPO-style update (harder, but closer to the original paper).
Add reward normalization (as Stable-Baselines’ adversary optionally does) and see how curves change.
Make the environment continuous-action and switch the policy to a Gaussian distribution.
References#
Ho & Ermon (2016), Generative Adversarial Imitation Learning: https://arxiv.org/abs/1606.03476
Stable-Baselines GAIL docs: https://stable-baselines.readthedocs.io/en/master/modules/gail.html
Stable-Baselines source (stable_baselines/gail/model.py): https://github.com/hill-a/stable-baselines/blob/master/stable_baselines/gail/model.py